package main

import (
	"crypto/rand"
	"crypto/rsa"
	"crypto/sha256"
	"errors"
	"fmt"
	"net/url"
	"os"

	concourseCmd "github.com/concourse/concourse/cmd"

	"github.com/concourse/concourse/atc/atccmd"
	"github.com/concourse/concourse/tsa/tsacmd"
	"github.com/concourse/flag/v2"
	"github.com/jessevdk/go-flags"
	"github.com/tedsuo/ifrit"
	"github.com/tedsuo/ifrit/grouper"
	"github.com/tedsuo/ifrit/sigmon"
)

// WebCommand bundles the ATC run command and the TSA command into a single
// command. Both are embedded, so their flags are parsed together, and Runner
// starts them side by side after filling in the settings they share.
type WebCommand struct {
	// PeerAddress is copied onto the embedded TSACommand by populateSharedFlags
	// before the runners are constructed.
	PeerAddress string `long:"peer-address" default:"127.0.0.1" description:"Network address of this web node, reachable by other web nodes. Used for forwarded worker addresses."`

	// RunCommand carries the ATC configuration; it is embedded without a
	// group or namespace tag.
	*atccmd.RunCommand
	// TSACommand carries the TSA configuration; its flags are grouped under
	// "TSA Configuration" and namespaced with "tsa".
	*tsacmd.TSACommand `group:"TSA Configuration" namespace:"tsa"`
}

// LessenRequirements marks as optional the flags that the web command is able
// to fill in on its own, so they need not be supplied by the operator. Every
// listed option is expected to exist on command; a missing option results in
// a nil dereference.
func (WebCommand) LessenRequirements(command *flags.Command) {
	optionalLongNames := []string{
		// defaults to atc external URL
		"tsa-atc-url",
		"tsa-token-url",

		// generated by default
		"session-signing-key",

		// defaults are derived from the signing key
		"client-secret",
		"tsa-client-secret",
	}

	for _, longName := range optionalLongNames {
		command.FindOptionByLongName(longName).Required = false
	}
}

// Execute builds the combined runner via Runner, invokes it wrapped in
// sigmon.New, and blocks until the resulting process exits. It returns the
// error from building the runner, or otherwise the process's exit error.
func (cmd *WebCommand) Execute(args []string) error {
	webRunner, err := cmd.Runner(args)
	if err != nil {
		return err
	}

	process := ifrit.Invoke(sigmon.New(webRunner))
	return <-process.Wait()
}

// Runner assembles a single ifrit.Runner that runs the ATC and the TSA in
// parallel.
//
// Before the individual runners are built, the CLI artifacts directory is
// defaulted to the discovered "fly-assets" asset when unset, and the settings
// shared between the two commands are populated. Any error from populating
// the shared flags or from building either runner is returned as-is.
func (cmd *WebCommand) Runner(args []string) (ifrit.Runner, error) {
	if cmd.RunCommand.CLIArtifactsDir == "" {
		assetsPath := concourseCmd.DiscoverAsset("fly-assets")
		cmd.RunCommand.CLIArtifactsDir = flag.Dir(assetsPath)
	}

	if err := cmd.populateSharedFlags(); err != nil {
		return nil, err
	}

	atcRunner, err := cmd.RunCommand.Runner(args)
	if err != nil {
		return nil, err
	}

	tsaRunner, err := cmd.TSACommand.Runner(args)
	if err != nil {
		return nil, err
	}

	// Both members log under sessions of the same "web" logger.
	webLogger, _ := cmd.RunCommand.Logger.Logger("web")

	members := grouper.Members{
		{
			Name:   "atc",
			Runner: concourseCmd.NewLoggingRunner(webLogger.Session("atc-runner"), atcRunner),
		},
		{
			Name:   "tsa",
			Runner: concourseCmd.NewLoggingRunner(webLogger.Session("tsa-runner"), tsaRunner),
		},
	}

	return grouper.NewParallel(os.Interrupt, members), nil
}

// populateSharedFlags fills in the configuration that the ATC and TSA
// commands have in common, defaulting anything the operator left unset:
//
//   - the session signing key (a 2048-bit RSA key is generated when absent),
//   - the TSA's peer address, ATC URLs and token URL,
//   - the TSA and ATC client secrets (derived from the signing key), which
//     are also registered in the auth client table,
//   - the system claim values, when the system claim key is "aud",
//   - the TSA's cluster name settings, copied from the ATC.
//
// It returns an error if key generation fails or if, with an "aud" claim key,
// none of the system claim values matches the TSA client ID.
func (cmd *WebCommand) populateSharedFlags() error {
	atc, tsa := cmd.RunCommand, cmd.TSACommand

	// Prefer the configured signing key; otherwise generate one and store it
	// back on the ATC so both components use the same key.
	var signingKey *rsa.PrivateKey
	if configured := atc.Auth.AuthFlags.SigningKey; configured != nil && configured.PrivateKey != nil {
		signingKey = configured.PrivateKey
	} else {
		generated, err := rsa.GenerateKey(rand.Reader, 2048)
		if err != nil {
			return fmt.Errorf("failed to generate session signing key: %s", err)
		}

		signingKey = generated
		atc.Auth.AuthFlags.SigningKey = &flag.PrivateKey{PrivateKey: generated}
	}

	tsa.PeerAddress = cmd.PeerAddress

	// Without explicit ATC URLs the TSA talks to the ATC's default URL.
	if len(tsa.ATCURLs) == 0 {
		tsa.ATCURLs = []flag.URL{atc.DefaultURL()}
	}

	// Without an explicit token URL, use the token path resolved against the
	// ATC's default URL.
	if tsa.TokenURL.URL == nil {
		tokenPath, _ := url.Parse("/sky/issuer/token")
		tsa.TokenURL.URL = atc.DefaultURL().URL.ResolveReference(tokenPath)
	}

	// Unset client secrets are derived from the signing key and client ID.
	if tsa.ClientSecret == "" {
		tsa.ClientSecret = derivedCredential(signingKey, tsa.ClientID)
	}

	if atc.Server.ClientSecret == "" {
		atc.Server.ClientSecret = derivedCredential(signingKey, atc.Server.ClientID)
	}

	// Register both clients; the TSA entry is written first, so the ATC entry
	// wins if the two client IDs are the same.
	atc.Auth.AuthFlags.Clients[tsa.ClientID] = tsa.ClientSecret
	atc.Auth.AuthFlags.Clients[atc.Server.ClientID] = atc.Server.ClientSecret

	// if we're using the 'aud' as the SystemClaimKey then we want to validate
	// that the SystemClaimValues contains our TSA Client. If it's not 'aud' then
	// we can't validate anything
	if atc.SystemClaimKey == "aud" {
		// if we're using the default SystemClaimValues then override these values
		// to make sure they include the TSA ClientID
		if values := atc.SystemClaimValues; len(values) == 1 && values[0] == "concourse-worker" {
			atc.SystemClaimValues = []string{tsa.ClientID}
		}

		if err := cmd.validateSystemClaimValues(); err != nil {
			return err
		}
	}

	tsa.ClusterName = atc.Server.ClusterName
	tsa.LogClusterName = atc.LogClusterName

	return nil
}

// validateSystemClaimValues checks that the TSA client ID appears among the
// ATC's system claim values. It returns nil when a match exists and an error
// otherwise (including when the list is empty).
func (cmd *WebCommand) validateSystemClaimValues() error {
	tsaClientID := cmd.TSACommand.ClientID

	for _, claimValue := range cmd.RunCommand.SystemClaimValues {
		if claimValue == tsaClientID {
			return nil
		}
	}

	return errors.New("at least one systemClaimValue must be equal to tsa-client-id")
}

// derivedCredential deterministically derives a client secret from the
// session signing key and a client ID.
//
// The result is the lowercase hex SHA-256 digest of the client ID bytes
// followed by the base-10 representation of the key's private exponent D.
// The same key and client ID always yield the same secret, so components
// sharing the signing key agree on it without further configuration.
//
// The private exponent is used deliberately: the modulus N is part of the
// public key, so a secret derived from N could be recomputed by anyone who
// holds the public key and knows the client ID.
func derivedCredential(key *rsa.PrivateKey, clientID string) string {
	// Append writes D in decimal after the client ID bytes.
	material := key.D.Append([]byte(clientID), 10)
	digest := sha256.Sum256(material)

	return fmt.Sprintf("%x", digest)
}
